-
Notifications
You must be signed in to change notification settings - Fork 634
[pooling]Fix weight loading for st_projector in embed models #4297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
|
||
def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Layer]: | ||
try: | ||
print("Loading ST Projector...") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
try: | ||
file_bytes = get_hf_file_to_dict(file_path, model_config.model, model_config.revision) | ||
if not file_bytes: | ||
print(file_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
调试的print删掉吧
act_fn_name = act_fn_name.lower() | ||
|
||
if act_fn_name.startswith("paddle.nn.Layer"): | ||
if act_fn_name.startswith(("paddle.nn.Layer", "torch.nn.modules")): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
把paddle.nn.Layer删除吧
file_path = f"{model_config.model}/{folder}/{filename}" if folder else filename | ||
|
||
try: | ||
file_bytes = get_hf_file_to_dict(file_path, model_config.model, model_config.revision) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么要删除?
修复 embedding 模型 st_projector 权重加载
embed 模型中的 st_projector 模块无法正常加载权重。
本 PR 修改
load_weights
方法,确保 st_projector 参数正确加载,从而保证模型初始化正常。影响范围:仅限 embed 模型的 st_projector。